using ArgParse, JLD2, Printf, JSON, Dates, IterTools, Random;
using Distributed;

@everywhere include("runit.jl");
@everywhere include("helpers.jl");
@everywhere include("../binary_search.jl");
include("helpers_experiments.jl");

"""
    parse_commandline()

Declare the command-line options of this experiment script and parse them.

Returns the result of `parse_args`; the rest of the script indexes it by option name
without the leading dashes (e.g. `"save_dir"`, `"Nruns"`, `"wdeltas"`).
"""
function parse_commandline()
    settings = ArgParseSettings()

    @add_arg_table! settings begin
        "--save_dir"
            help = "Directory for saving the experiment's data."
            arg_type = String
            default = "experiments/"
        "--data_dir"
            help = "Directory for loading the data."
            arg_type = String
            default = "data/"
        "--seed"
            help = "Seed."
            arg_type = Int64
            default = 42
        "--inst"
            help = "Instance considered."
            arg_type = String
            default = "ber_eq3rd"
        "--K"
            help = "Number of arms."
            arg_type = Int64
            default = 4
        "--B"
            help = "Upper bound."
            arg_type = Float64
            default = 1.0
        "--expe"
            help = "Experiment considered."
            arg_type = String
            default = "test"
        "--Nruns"
            help = "Number of runs of the experiment."
            arg_type = Int64
            default = 8
        # Boolean flag: present on the command line => true.
        "--wdeltas"
            help = "Run experiments with extended values for delta."
            action = :store_true
    end

    return parse_args(settings)
end

# Parameters: copy the parsed command-line options into script-level variables.
parsed_args = parse_commandline()
save_dir, data_dir = parsed_args["save_dir"], parsed_args["data_dir"]
seed, inst, expe = parsed_args["seed"], parsed_args["inst"], parsed_args["expe"]
Nruns, wdeltas = parsed_args["Nruns"], parsed_args["wdeltas"]
# NOTE(review): parsed_args["B"] is not read anywhere in the visible script; `B` is
# later taken from the loaded/generated instance data instead. Confirm this is intended.

# Number of arms: for instance names containing "dsat" it is parsed from the second
# "_"-separated field of the name; otherwise the --K option is used.
nK = if occursin("dsat", inst)
    parse(Int64, split(inst, "_")[2])
else
    parsed_args["K"]
end

# Storing parameters defining the instance
param_inst = Dict("inst" => inst, "nK" => nK, "data_dir" => data_dir)

# Values of δ handed to every run: the extended list with --wdeltas, a single value otherwise.
δs = if wdeltas
    [0.1, 0.01, 0.001]
else
    [0.01]
end

# Naming files and folder.
# Paths are assembled with `joinpath`, so --data_dir / --save_dir behave the same with or
# without a trailing "/" (plain concatenation glued the directory name onto the file name
# when the separator was missing).
name_data = inst * "_K" * string(nK)
data_file = joinpath(data_dir, name_data * ".dat")
now_str = Dates.format(now(), "dd-mm_HHhMM")
experiment_name = string(
    "exp_", name_data, "_", expe, wdeltas ? "_delta" : "", "_N", Nruns
)
# Keep the trailing "/": later statements append file names to `experiment_dir` directly.
experiment_dir = joinpath(save_dir, now_str * ":" * experiment_name) * "/"

# Create the parent directory if it is missing (e.g. first run with the default
# "experiments/"). The run directory itself is still created with `mkdir`, so an already
# existing directory of the same name raises instead of being silently reused.
isempty(save_dir) || mkpath(save_dir)
mkdir(experiment_dir)

# Record the exact command-line options next to the results.
open("$(experiment_dir)parsed_args.json", "w") do f
    JSON.print(f, parsed_args)
end

# For reproducibility, load the data if already defined.
if isfile(data_file)
    # Note: this also replaces `param_inst` with the value stored in the file.
    @load data_file dists μs B Tstar wstar param_inst
else
    @warn "Generating new data."

    # Make sure the target directory exists *before* sampling and calling the oracle,
    # so that a missing directory cannot make the final `@save` fail after the work is done.
    isempty(dirname(data_file)) || mkpath(dirname(data_file))

    # Parameters
    μs, dists, B = get_instance_experiment(param_inst)

    # Oracle: evaluated on the empirical means of `n` samples drawn per distribution,
    # using a generator seeded with `seed`.
    n = 25000
    rng = MersenneTwister(seed)
    Xs = [[sample(rng, dist) for _ in 1:n] for dist in dists]
    μs_emp = mean.(Xs)
    pep = BestArm(dists, B)
    Tstar, wstar = oracle(pep, μs_emp, Xs)

    @save data_file dists μs B Tstar wstar param_inst
end
# Keep a copy of the instance data inside the experiment folder. `data_file` is part of
# the saved variable list, so the path of the source data file is stored as well.
@save "$(experiment_dir)$(name_data).dat" data_file dists μs B Tstar wstar param_inst

# Get Tau_max.
# `lbd` is computed from the smallest δ and `Tstar`; it only caps the runs when the
# experiment name contains "fail_equ", otherwise the fixed value 1e8 is used.
min_δ = minimum(δs)
lbd = (1 - 2 * min_δ) * log((1 - min_δ) / min_δ) * Tstar
timeout_factor = if occursin("GK16", expe)
    15
else
    60
end
Tau_max = if occursin("fail_equ", expe)
    timeout_factor * lbd
else
    1e8
end

# Pure exploration problem
pep = BestArm(dists, B)

# Identification strategies used on this instance: tuple (sr, rsp)
iss = everybody(expe, wstar)

# Run the experiments in parallel: one `runit` call per (strategy, run index) pair,
# with run `i` seeded by `seed + i`.
@time data = pmap(Iterators.product(iss, 1:Nruns)) do (is, i)
    runit(seed + i, is, pep, δs, Tau_max)
end

# Save everything using JLD2.
@save "$(experiment_dir)$(experiment_name).dat" dists μs Tstar wstar pep iss data δs Nruns seed

# Print a summary of the problem we considered
file = "$(experiment_dir)summary_$(experiment_name).txt"
print_summary(pep, dists, μs, Tstar, wstar, δs, iss, data, Nruns, file)
